/*
 * MIT License
 *
 * Copyright (c) 2023 北京凯特伟业科技有限公司
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package com.je.ibatis.extension.plugins.inner;

import com.je.ibatis.extension.metadata.model.Table;
import com.je.ibatis.extension.plugins.inner.tenant.DynaTenant;
import com.je.ibatis.extension.plugins.inner.tenant.DynaTenantContext;
import com.je.ibatis.extension.util.PluginUtils;
import com.je.ibatis.session.CustomConfiguration;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.ValueListExpression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.springframework.util.StringUtils;

import java.sql.Connection;
import java.util.Map;
import java.util.Properties;

/**
 * 内部租户插件
 */
public class TenantInnerInterceptor extends AbstractInnerInterceptor {

    /**
     * 全局数据，默认为公共数据，所有租户是可用的
     */
    protected static final String GLOBAL_TENANT = "global";
    /**
     * 系统租户
     */
    protected static final String SYSTEM_TENANT = "system";

    protected final String MARK = "tenant";

    /**
     * 表租户ID字段
     */
    private String tableTenantIdField = "SY_TENANT_ID";
    private String tableTenantNameField = "SY_TENANT_NAME";
    /**
     * 模型租户ID成员变量
     */
    private String modelTenantIdField = "tenantId";
    private String modelTenantNameField = "tenantName";

    public TenantInnerInterceptor() {
        this.enable = false;
    }

    public String getTableTenantIdField() {
        return tableTenantIdField;
    }

    public void setTableTenantIdField(String tableTenantIdField) {
        this.tableTenantIdField = tableTenantIdField;
    }

    public String getTableTenantNameField() {
        return tableTenantNameField;
    }

    public void setTableTenantNameField(String tableTenantNameField) {
        this.tableTenantNameField = tableTenantNameField;
    }

    public String getModelTenantIdField() {
        return modelTenantIdField;
    }

    public void setModelTenantIdField(String modelTenantIdField) {
        this.modelTenantIdField = modelTenantIdField;
    }

    public String getModelTenantNameField() {
        return modelTenantNameField;
    }

    public void setModelTenantNameField(String modelTenantNameField) {
        this.modelTenantNameField = modelTenantNameField;
    }

    @Override
    public String getMark() {
        return MARK;
    }

    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) throws JSQLParserException {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
        if (sct == SqlCommandType.SELECT || sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            if (ms.getSqlCommandType() == SqlCommandType.INSERT) {
                proceedInsert(mpSh, ms, mpBs.getDelegate());
            } else if (ms.getSqlCommandType() == SqlCommandType.UPDATE) {
                proceedUpdate(ms, mpBs.getDelegate());
            } else if (ms.getSqlCommandType() == SqlCommandType.DELETE) {
                proceedDelete(ms, mpBs.getDelegate());
            } else if (ms.getSqlCommandType() == SqlCommandType.SELECT) {
                proceedSelect(ms, mpBs.getDelegate());
            }
        }
    }

    /**
     * 进行select处理
     *
     * @param ms
     * @param boundSql
     * @return
     */
    protected void proceedSelect(MappedStatement ms, BoundSql boundSql) throws JSQLParserException {
        CustomConfiguration configuration = (CustomConfiguration) ms.getConfiguration();

        //如果是按照sql查询
        if (StringUtils.isEmpty(getTableTenantIdField())) {
            return;
        }
        if (!boundSql.getSql().trim().startsWith("select") && !boundSql.getSql().trim().startsWith("SELECT")) {
            return;
        }
        Select select = (Select) CCJSqlParserUtil.parse(boundSql.getSql());
        SelectBody selectBody = select.getSelectBody();
        PlainSelect plainSelect = (PlainSelect) selectBody;

        String tableCode = null;
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem instanceof net.sf.jsqlparser.schema.Table) {
            net.sf.jsqlparser.schema.Table fromTable = (net.sf.jsqlparser.schema.Table) fromItem;
            tableCode = fromTable.getName();
        }

        if (StringUtils.isEmpty(tableCode)) {
            return;
        }

        Table table = configuration.getMetaStatementBuilder().getCacheManager().getTable(tableCode);
        if (table == null) {
            table = configuration.getMetaStatementBuilder().table(tableCode);
        }
        //是否包含租户字段
        if (table.getColumns().get(getTableTenantIdField()) == null) {
            return;
        }

        //进行where改写
        Expression whereExpression = plainSelect.getWhere();
        if (whereExpression == null) {
            plainSelect.setWhere(buildTenantSelectWhere(plainSelect));
        } else {
            if (!checkWhereHasTenantField(whereExpression)) {
                AndExpression andExpression = new AndExpression();
                andExpression.withLeftExpression(whereExpression);
                andExpression.withRightExpression(buildTenantSelectWhere(plainSelect));
                plainSelect.setWhere(andExpression);
            }
        }
        PluginUtils.mpBoundSql(boundSql).sql(plainSelect.toString());
    }

    private boolean checkWhereHasTenantField(Expression whereExpression) {
        if (whereExpression instanceof EqualsTo) {
            EqualsTo equalsTo = (EqualsTo) whereExpression;
            if (equalsTo.getLeftExpression() instanceof Column) {
                Column column = (Column) equalsTo.getLeftExpression();
                if (column.getColumnName().equals(getTableTenantIdField())) {
                    return true;
                }
            }
        } else if (whereExpression instanceof NotEqualsTo) {
            NotEqualsTo notEqualsTo = (NotEqualsTo) whereExpression;
            if (notEqualsTo.getLeftExpression() instanceof Column) {
                Column column = (Column) notEqualsTo.getLeftExpression();
                if (column.getColumnName().equals(getTableTenantIdField())) {
                    return true;
                }
            }
        } else if (whereExpression instanceof InExpression) {
            InExpression inExpression = (InExpression) whereExpression;
            if (inExpression.getLeftExpression() instanceof Column) {
                Column column = (Column) inExpression.getLeftExpression();
                if (column.getColumnName().equals(getTableTenantIdField())) {
                    return true;
                }
            }
        }
        
        if (whereExpression instanceof AndExpression) {
            AndExpression andExpression = (AndExpression) whereExpression;
            if (checkWhereHasTenantField(andExpression.getLeftExpression())) {
                return true;
            }
            if (checkWhereHasTenantField(andExpression.getRightExpression())) {
                return true;
            }
        } else if (whereExpression instanceof OrExpression) {
            OrExpression orExpression = (OrExpression) whereExpression;
            if (checkWhereHasTenantField(orExpression.getLeftExpression())) {
                return true;
            }
            if (checkWhereHasTenantField(orExpression.getRightExpression())) {
                return true;
            }
        }

        return false;
    }

    /**
     * 进行insert处理
     *
     * @param ms
     * @param boundSql
     * @return
     */
    protected void proceedInsert(PluginUtils.MPStatementHandler mpSh, MappedStatement ms, BoundSql boundSql) throws JSQLParserException {
        //获取当前上下文租户
        DynaTenant tenant = DynaTenantContext.getCurrentTenant();
        if (tenant == null) {
            throw new RuntimeException("The insert contains tenant field, but can't find the tenant in the tenant context!");
        }

        CustomConfiguration configuration = (CustomConfiguration) ms.getConfiguration();
        Insert insert = (Insert) CCJSqlParserUtil.parse(boundSql.getSql());
        String tableCode = insert.getTable().getName();
        if (StringUtils.isEmpty(tableCode)) {
            return;
        }

        Table table = configuration.getMetaStatementBuilder().getCacheManager().getTable(tableCode);
        if (table == null) {
            table = configuration.getMetaStatementBuilder().table(tableCode);
        }
        //是否包含租户字段
        if (table.getColumns().get(getTableTenantIdField()) == null) {
            return;
        }

        //如果包含租户字段，则不再添加
        for (Column eachColumn : insert.getColumns()) {
            if (getTableTenantIdField().equals(eachColumn.getColumnName())) {
                return;
            }
        }

        if (insert.getSelect() == null) {
            //insert select
            Integer tenantColumnIndex = null;
            for (int i = 0; i < insert.getColumns().size(); i++) {
                if (insert.getColumns().get(i).getColumnName().equals(getTableTenantIdField())) {
                    tenantColumnIndex = i;
                    break;
                }
            }

            if (tenantColumnIndex == null) {
                Column idColumn = new Column(null, getTableTenantIdField());
                Column nameColumn = new Column(null, getTableTenantNameField());
                insert.addColumns(idColumn, nameColumn);
                recursiveSetTenantColumnValue(insert.getItemsList());
            } else {
                Object params = mpSh.parameterHandler().getParameterObject();
                if (params instanceof Map) {
                    Map<String, Object> mapParams = (Map<String, Object>) params;
                    mapParams.put(getTableTenantIdField(), DynaTenantContext.getCurrentTenant().getId());
                    mapParams.put(getTableTenantNameField(), DynaTenantContext.getCurrentTenant().getName());
                }
            }
            PluginUtils.mpBoundSql(boundSql).sql(insert.toString());
        } else {

        }
    }

    /**
     * 进行更新处理
     *
     * @param ms
     * @param boundSql
     * @return
     */
    protected void proceedUpdate(MappedStatement ms, BoundSql boundSql) throws JSQLParserException {

        //获取当前上下文租户
        DynaTenant tenant = DynaTenantContext.getCurrentTenant();
        if (tenant == null) {
            throw new RuntimeException("The update contains tenant field, but can't find the tenant in the tenant context!");
        }

        if (!boundSql.getSql().trim().startsWith("update") && !boundSql.getSql().trim().startsWith("UPDATE")) {
            return;
        }

        CustomConfiguration configuration = (CustomConfiguration) ms.getConfiguration();
        Update update = (Update) CCJSqlParserUtil.parse(boundSql.getSql());
        String tableCode = update.getTable().getName();
        if (StringUtils.isEmpty(tableCode)) {
            return;
        }

        Table table = configuration.getMetaStatementBuilder().getCacheManager().getTable(tableCode);
        if (table == null) {
            table = configuration.getMetaStatementBuilder().table(tableCode);
        }
        //是否包含租户字段
        if (table.getColumns().get(getTableTenantIdField()) == null) {
            return;
        }

        Expression whereExpression = update.getWhere();
        if (whereExpression == null) {
            update.setWhere(buildTenantUpdateWhere(update));
        } else {
            if (!checkWhereHasTenantField(whereExpression)) {
                AndExpression andExpression = new AndExpression();
                andExpression.withLeftExpression(whereExpression);
                andExpression.withRightExpression(buildTenantUpdateWhere(update));
                update.setWhere(andExpression);
            }
        }
        PluginUtils.mpBoundSql(boundSql).sql(update.toString());
    }

    /**
     * 删除delete处理
     *
     * @param ms
     * @param boundSql
     * @return
     */
    protected void proceedDelete(MappedStatement ms, BoundSql boundSql) throws JSQLParserException {

        //获取当前上下文租户
        DynaTenant tenant = DynaTenantContext.getCurrentTenant();
        if (tenant == null) {
            throw new RuntimeException("The delete contains tenant field, but can't find the tenant in the tenant context!");
        }

        if (!boundSql.getSql().trim().startsWith("delete") && !boundSql.getSql().trim().startsWith("DELETE")) {
            return;
        }

        CustomConfiguration configuration = (CustomConfiguration) ms.getConfiguration();
        Delete delete = (Delete) CCJSqlParserUtil.parse(boundSql.getSql());
        String tableCode = delete.getTable().getName();
        if (StringUtils.isEmpty(tableCode)) {
            return;
        }

        Table table = configuration.getMetaStatementBuilder().getCacheManager().getTable(tableCode);
        if (table == null) {
            table = configuration.getMetaStatementBuilder().table(tableCode);
        }
        //是否包含租户字段
        if (table.getColumns().get(getTableTenantIdField()) == null) {
            return;
        }

        Expression whereExpression = delete.getWhere();
        if (whereExpression == null) {
            delete.setWhere(buildTenantDeleteWhere(delete));
        } else {
            if (!checkWhereHasTenantField(whereExpression)) {
                AndExpression andExpression = new AndExpression();
                andExpression.withLeftExpression(whereExpression);
                andExpression.withRightExpression(buildTenantDeleteWhere(delete));
                delete.setWhere(andExpression);
            }
        }
        PluginUtils.mpBoundSql(boundSql).sql(delete.toString());
    }

    private Expression buildTenantUpdateWhere(Update update) {
        net.sf.jsqlparser.schema.Table table = update.getTable();
        StringBuilder sb = new StringBuilder();
        if (table.getAlias() == null || StringUtils.isEmpty(table.getAlias().getName())) {
            sb.append(getTableTenantIdField());
        } else {
            sb.append(table.getAlias().getName());
            sb.append(".");
            sb.append(getTableTenantIdField());
        }

        Column column = new Column(null, sb.toString());
        if (SYSTEM_TENANT.equals(DynaTenantContext.getCurrentTenant().getId())) {
            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(column);
            ExpressionList valueList = new ExpressionList();
            valueList.addExpressions(new StringValue(GLOBAL_TENANT), new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            ValueListExpression listExpression = new ValueListExpression();
            listExpression.setExpressionList(valueList);
            inExpression.setRightExpression(listExpression);
            return inExpression;
        } else {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(column);
            equalsTo.setRightExpression(new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            return equalsTo;
        }
    }

    private Expression buildTenantSelectWhere(PlainSelect select) {
        FromItem fromItem = select.getFromItem();
        StringBuilder sb = new StringBuilder();
        if (fromItem.getAlias() == null || StringUtils.isEmpty(fromItem.getAlias().getName())) {
            sb.append(getTableTenantIdField());
        } else {
            sb.append(fromItem.getAlias().getName());
            sb.append(DOT);
            sb.append(getTableTenantIdField());
        }

        Column column = new Column(null, sb.toString());
        if (DynaTenantContext.getCurrentTenant() == null) {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(column);
            equalsTo.setRightExpression(new StringValue(GLOBAL_TENANT));
            return equalsTo;
        } else {
            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(column);
            ExpressionList valueList = new ExpressionList();
            valueList.addExpressions(new StringValue(GLOBAL_TENANT), new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            ValueListExpression listExpression = new ValueListExpression();
            listExpression.setExpressionList(valueList);
            inExpression.setRightExpression(listExpression);
            return inExpression;
        }
    }

    private Expression buildTenantDeleteWhere(Delete delete) {
        net.sf.jsqlparser.schema.Table table = delete.getTable();
        StringBuilder sb = new StringBuilder();
        if (table.getAlias() == null || StringUtils.isEmpty(table.getAlias().getName())) {
            sb.append(getTableTenantIdField());
        } else {
            sb.append(table.getAlias().getName());
            sb.append(DOT);
            sb.append(getTableTenantIdField());
        }

        Column column = new Column(null, sb.toString());
        if (SYSTEM_TENANT.equals(DynaTenantContext.getCurrentTenant().getId())) {
            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(column);
            ExpressionList valueList = new ExpressionList();
            valueList.addExpressions(new StringValue(GLOBAL_TENANT), new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            ValueListExpression listExpression = new ValueListExpression();
            listExpression.setExpressionList(valueList);
            inExpression.setRightExpression(listExpression);
            return inExpression;
        } else {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(column);
            equalsTo.setRightExpression(new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            return equalsTo;
        }

    }

    private void recursiveSetTenantColumnValue(ItemsList itemsList) {
        if (itemsList instanceof ExpressionList) {
            ExpressionList expressionList = (ExpressionList) itemsList;
            expressionList.addExpressions(new StringValue(DynaTenantContext.getCurrentTenant().getId()));
            expressionList.addExpressions(new StringValue(DynaTenantContext.getCurrentTenant().getName()));
        } else if (itemsList instanceof MultiExpressionList) {
            MultiExpressionList multiExpressionList = (MultiExpressionList) itemsList;
            for (ExpressionList eachExpressionList : multiExpressionList.getExpressionLists()) {
                recursiveSetTenantColumnValue(eachExpressionList);
            }
        }
    }

    @Override
    public void setProperties(Properties properties) {
        String enableField = getMark() + DOT + ENABLE_FIELD;
        if (properties.containsKey(enableField)) {
            setEnable("1".equals(properties.getProperty(enableField)) || "true".equals(properties.getProperty(enableField)));
        } else {
            setEnable(false);
        }

        String tenantTableIdField = getMark() + getTableTenantIdField();
        if (properties.containsKey(tenantTableIdField)) {
            setTableTenantIdField(properties.getProperty(tenantTableIdField));
        }

        String tenantTableNameField = getMark() + getTableTenantNameField();
        if (properties.containsKey(tenantTableNameField)) {
            setTableTenantNameField(properties.getProperty(tenantTableNameField));
        }

        String tenantModelIdField = getMark() + getModelTenantIdField();
        if (properties.containsKey(tenantModelIdField)) {
            setModelTenantIdField(properties.getProperty(tenantModelIdField));
        }

        String tenantModelNameField = getMark() + getModelTenantIdField();
        if (properties.containsKey(tenantModelNameField)) {
            setModelTenantNameField(properties.getProperty(tenantModelNameField));
        }

    }
}
